package com.zwcl.common.mq;

import com.alibaba.fastjson.JSON;
import lombok.extern.slf4j.Slf4j;
import org.apache.rocketmq.spring.core.RocketMQLocalTransactionListener;
import org.apache.rocketmq.spring.core.RocketMQLocalTransactionState;
import org.apache.rocketmq.spring.support.RocketMQHeaders;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;

import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.TimeUnit;

/**
 * RocketMq本地事务监听器实现，监听一个txProducerGroup，不同的txProducerGroup需要自己实现不同的监听器
 * 实现了以下功能：
 * 1.MQ 在消息状态异常情况下对本地事务执行状态进行检查
 * 2.在同一producer获得新的全局唯一
 */
@Slf4j
public abstract class RocketMQLocalTransactionListenerAware<E> implements RocketMQLocalTransactionListener, InitializingBean {

    private Class messageType;

    @Autowired
    RedisTemplate<String, Object> redisTemplate;

    @Override
    public RocketMQLocalTransactionState checkLocalTransaction(Message msg) {

        MessageHeaders headers = msg.getHeaders();
        String transId = (String) headers.get(RocketMQHeaders.TRANSACTION_ID);
        Object statusObj = redisTemplate.opsForValue().get(transId);
        int status = 0;
        if (statusObj != null) {
            status = Integer.parseInt(statusObj.toString());
        }

        RocketMQLocalTransactionState transactionState;
        switch (status) {
            case 1:
                transactionState = RocketMQLocalTransactionState.COMMIT;
                break;
            case 2:
                transactionState = RocketMQLocalTransactionState.ROLLBACK;
                break;
            default:
                transactionState = RocketMQLocalTransactionState.UNKNOWN;
                break;
        }
        log.info("--- The local transaction was executed once, transactionId={}, status={}, transactionState={} ---", transId, status, transactionState);
        return transactionState;
    }


    @Override
    public RocketMQLocalTransactionState executeLocalTransaction(final Message msg, final Object arg) {
        MessageHeaders headers = msg.getHeaders();
        String transId = (String) headers.get(RocketMQHeaders.TRANSACTION_ID);

        try {
            //执行本地事务（业务逻辑代码）
            byte[] array  = (byte[])msg.getPayload();
            Object message = JSON.parseObject(new String(array), messageType);
            execute((E)message, arg);
            redisTemplate.opsForValue().set(transId, 1,30, TimeUnit.MINUTES);
            log.info("local transaction was successfully executed, transactionId={}, msg={}", transId, msg.getPayload());
            return RocketMQLocalTransactionState.COMMIT;


        } catch (Exception e) {
            log.error("execute local transactionId={}, error msg[{}]", transId, e.getMessage());
            redisTemplate.opsForValue().set(transId, 2,30, TimeUnit.MINUTES);
            return RocketMQLocalTransactionState.ROLLBACK;
        }

    }

    /**
     * 执行本地事务
     *
     * @param msg
     * @param arg
     */
    protected abstract void execute(final E msg, final Object arg);

    @Override
    public void afterPropertiesSet() {
        this.messageType = getMessageType();
        log.debug("RocketMQ messageType: {}", messageType.getName());
    }

    private Class getMessageType() {

        return (Class) ((ParameterizedType) this.getClass().getGenericSuperclass()).getActualTypeArguments()[0];
    }

}